import os
import re
import sys
from collections import defaultdict

import pysam
from numpy import *
from Bio.Seq import reverse_complement
from Bio import SeqIO


# Command-line arguments: the dataset name (a directory under CASPARs) and the
# library name (stem of the BAM file in <dataset>/Mapping).
dataset, library = sys.argv[1:]

# Alignment targets (XT tag values) that are used directly as the annotation
# of an alignment, without consulting the XA tag.
annotations = (
    "rRNA", "tRNA", "snRNA", "snoRNA", "scRNA", "scaRNA", "vRNA",
    "yRNA", "snar", "histone", "chrM", "RMRP", "RPPH",
)

def read_scrna_annotations():
    """Return {record ID: (gene label, biogenesis)} for the scRNA filter FASTA.

    Every header must look like "<id> <description> (<gene>), <category>";
    any deviation trips an assertion.  BCYRN1, the RN7SL832P pseudogene and
    RN7SL1-3 are labelled "Pol-III"; MASCRNA is labelled "intronic".  Any
    other small cytoplasmic RNA raises Exception.
    """
    path = os.path.join("/osc-fs_home/mdehoon/Data/CASPARs/Filters", "scRNA.fa")
    table = {}
    for entry in SeqIO.parse(path, "fasta"):
        # Split "<id> <text> (<symbol>), <kind>" into its parts.
        head, kind = entry.description.rsplit(", ", 1)
        identifier, text = head.split(None, 1)
        assert identifier == entry.id
        assert text.endswith(")")
        open_paren = text.rindex("(")
        symbol = text[open_paren + 1:-1]
        text = text[:open_paren].rstrip()

        if kind == "long non-coding RNA":
            assert text == "Homo sapiens brain cytoplasmic RNA 1"
            assert symbol == "BCYRN1"
            table[entry.id] = ("brain cytoplasmic RNA 1", "Pol-III")
            continue
        if kind == "non-coding RNA":
            assert text == "Homo sapiens RNA, 7SL, cytoplasmic 832, pseudogene"
            assert symbol == "RN7SL832P"
            table[entry.id] = ("7SL", "Pol-III")
            continue

        assert kind == "small cytoplasmic RNA"
        if text.startswith("Homo sapiens RNA component of signal recognition particle 7SL"):
            assert symbol in ("RN7SL1", "RN7SL2", "RN7SL3")
            table[entry.id] = ("7SL", "Pol-III")
        elif text == "Homo sapiens MALAT1-associated small cytoplasmic RNA":
            assert symbol == "MASCRNA"
            table[entry.id] = (symbol, "intronic")
        else:
            raise Exception("Unknown gene '%s' with description '%s'" % (symbol, text))
    return table

def read_snrna_annotations():
    """Return {record ID: (snRNA label, biogenesis)} for the snRNA filter FASTA.

    Two header layouts are handled:

    * Ensembl records (ID starts with "ENST"): the text after the ID is
      "name|description|rfam".  The snRNA class is taken from the Rfam family,
      else from the description, else (when both are empty) from the gene
      name, else from a fixed table of individual transcript IDs.
    * Other (RefSeq-style) records: "<id> <description> (<gene>), <category>",
      validated with assertions.

    Spliceosomal classes are labelled "<class> spliceosomal RNA" with
    biogenesis "Pol-II" (U1, U2, U4, U4atac, U5, U7, U11, U12) or "Pol-III"
    (U6, U6atac).  7SK, BCYRN1 and RN7SL832P keep their own label with
    biogenesis "Pol-III".

    Raises AssertionError for a malformed RefSeq header and Exception for an
    unrecognised gene or an Ensembl record that cannot be classified.
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "snRNA.fa"
    path = os.path.join(directory, filename)

    # Ensembl: snRNA class keyed on the Rfam family name (third "|" field).
    rfam_classes = {
        "U1 spliceosomal RNA": "U1",
        "U2 spliceosomal RNA": "U2",
        "U4 spliceosomal RNA": "U4",
        "U4atac minor spliceosomal RNA": "U4atac",
        "U5 spliceosomal RNA": "U5",
        "U6 spliceosomal RNA": "U6",
        "U6atac minor spliceosomal RNA": "U6atac",
        "U7 small nuclear RNA": "U7",
        "U11 spliceosomal RNA": "U11",
        "U12 minor spliceosomal RNA": "U12",
    }
    # Ensembl records without a recognised Rfam family: exact descriptions.
    # (U1 descriptions are matched by prefix below.)
    description_classes = {
        "Homo sapiens U6 spliceosomal RNA": "U6",
        "Homo sapiens (human) U6 spliceosomal RNA": "U6",
        "Homo sapiens (human) U6 spliceosomal RNA (multiple genes)": "U6",
    }
    # Ensembl records with neither description nor Rfam family: gene symbol
    # (the part of a "<symbol>-<number>" name before the hyphen).  RNVU1 and
    # RNU5E mirror the symbols accepted by the RefSeq branch below.
    gene_classes = {
        "RNU1": "U1",
        "RNVU1": "U1",
        "RNU2": "U2",
        "RNU4": "U4",
        "RNU5A": "U5",
        "RNU5B": "U5",
        "RNU5D": "U5",
        "RNU5E": "U5",
        "RNU5F": "U5",
        "RNU6": "U6",
        "RNU7": "U7",
    }
    # Last resort: individual Ensembl transcripts none of the rules classify.
    transcript_classes = {
        "ENST00000620626.1": "U4",
        "ENST00000636749.1": "U6",
        "ENST00000636931.1": "U6",
        "ENST00000636425.1": "U6",
        "ENST00000620349.1": "U4",
        "ENST00000637085.1": "U6",
        "ENST00000516584.1": "U6",
        "ENST00000516940.1": "U7",
        "ENST00000637295.1": "U2",
        "ENST00000636829.1": "U6",
        "ENST00000619968.1": "U7",
        "ENST00000618345.1": "U4atac",
        "ENST00000516898.1": "U11",
        "ENST00000619194.1": "U6",
        "ENST00000647487.1": "U7",
        "ENST00000646220.1": "U6",
    }

    snrna_annotations = {}
    for record in SeqIO.parse(path, "fasta"):
        # Reset per record so no value can leak in from the previous record.
        snRNA = None
        biogenesis = None
        if record.id.startswith("ENST"):
            assert record.description.startswith(record.id)
            description = record.description[len(record.id)+1:]
            name, description, rfam = description.rsplit("|")
            if rfam in rfam_classes:
                snRNA = rfam_classes[rfam]
            elif description in description_classes:
                snRNA = description_classes[description]
            elif description.startswith("Homo sapiens (human) U1 spliceosomal RNA"):
                snRNA = "U1"
            elif not description and not rfam:
                if name.startswith("RNU4ATAC"):
                    snRNA = "U4atac"
                elif name.startswith("RNU6ATAC"):
                    snRNA = "U6atac"
                elif name.startswith("RNA, U5A small nuclear "):
                    snRNA = "U5"
                elif name.startswith("RNA, U6 small nuclear "):
                    snRNA = "U6"
                else:
                    parts = name.split("-")
                    if len(parts) == 2:
                        snRNA = gene_classes.get(parts[0])
            if snRNA is None:
                snRNA = transcript_classes.get(record.id)
        else:
            # "<id> <description> (<gene>), <category>"
            description, category = record.description.rsplit(", ", 1)
            name, description = description.split(None, 1)
            assert name == record.id
            assert description.endswith(")")
            index = description.rindex("(")
            gene = description[index+1:-1]
            description = description[:index].rstrip()
            if category == "long non-coding RNA":
                assert description == "Homo sapiens brain cytoplasmic RNA 1"
                assert gene == "BCYRN1"
                snRNA = gene
                biogenesis = "Pol-III"
            elif category == "non-coding RNA":
                assert description == "Homo sapiens RNA, 7SL, cytoplasmic 832, pseudogene"
                assert gene == "RN7SL832P"
                snRNA = gene
                biogenesis = "Pol-III"
            else:
                assert category == "small nuclear RNA"
                if description == 'Homo sapiens RNA component of 7SK nuclear ribonucleoprotein':
                    assert gene == 'RN7SK'
                    snRNA = "7SK"
                    biogenesis = "Pol-III"
                elif description.startswith('Homo sapiens RNA, U1 small nuclear '):
                    assert gene.startswith('RNU1-')
                    snRNA = "U1"
                elif description.startswith('Homo sapiens RNA, variant U1 small nuclear '):
                    assert gene.startswith('RNVU1-')
                    snRNA = "U1"
                elif description.startswith('Homo sapiens RNA, U2 small nuclear '):
                    assert gene.startswith('RNU2-')
                    snRNA = "U2"
                elif description.startswith('Homo sapiens RNA, U4 small nuclear '):
                    assert gene.startswith('RNU4-')
                    snRNA = "U4"
                elif description.startswith('Homo sapiens RNA, U4atac small nuclear '):
                    assert gene == 'RNU4ATAC'
                    snRNA = "U4atac"
                elif description.startswith('Homo sapiens RNA, U5A small nuclear '):
                    assert gene.startswith('RNU5A-')
                    snRNA = "U5"
                elif description.startswith('Homo sapiens RNA, U5B small nuclear '):
                    assert gene.startswith('RNU5B-')
                    snRNA = "U5"
                elif description.startswith('Homo sapiens RNA, U5D small nuclear '):
                    assert gene.startswith('RNU5D-')
                    snRNA = "U5"
                elif description.startswith('Homo sapiens RNA, U5E small nuclear '):
                    assert gene.startswith('RNU5E-')
                    snRNA = "U5"
                elif description.startswith('Homo sapiens RNA, U5F small nuclear '):
                    assert gene.startswith('RNU5F-')
                    snRNA = "U5"
                elif description.startswith('Homo sapiens RNA, U6 small nuclear '):
                    assert gene.startswith('RNU6-')
                    snRNA = "U6"
                elif description.startswith('Homo sapiens RNA, U6atac small nuclear '):
                    assert gene == 'RNU6ATAC'
                    snRNA = "U6atac"
                elif description.startswith('Homo sapiens RNA, U7 small nuclear '):
                    assert gene.startswith('RNU7-')
                    snRNA = "U7"
                elif description == 'Homo sapiens RNA, U11 small nuclear':
                    assert gene == 'RNU11'
                    snRNA = "U11"
                elif description == 'Homo sapiens RNA, U12 small nuclear':
                    assert gene == 'RNU12'
                    snRNA = "U12"
                else:
                    raise Exception("Unknown gene '%s' with description '%s'" % (gene, description))
        if snRNA in ("U1", "U2", "U4", "U4atac", "U5", "U7", "U11", "U12"):
            snRNA = "%s spliceosomal RNA" % snRNA
            biogenesis = "Pol-II"
        elif snRNA in ("U6", "U6atac"):
            snRNA = "%s spliceosomal RNA" % snRNA
            biogenesis = "Pol-III"
        elif biogenesis is None:
            raise Exception("Unknown snRNA %s with description '%s'" % (snRNA, description))
        snrna_annotations[record.id] = (snRNA, biogenesis)
    return snrna_annotations

def read_snorna_annotations():
    """Return {record ID: (snoRNA label, biogenesis)} for the snoRNA filter FASTA.

    The label is "<class> snoRNA", where class is "U3", "U8", "U13",
    "C/D box" or "H/ACA box".  U3/U8/U13 are assigned biogenesis "Pol-II";
    C/D box and H/ACA box snoRNAs are assigned "intronic".

    RefSeq records (ID starts with "NR_") are classified from their
    description; Ensembl records (ID starts with "ENST") from the text after
    the ID, optionally preceded by an ENSG gene ID.  Raises Exception for a
    record that cannot be classified and AssertionError for an unexpected
    header layout.
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "snoRNA.fa"
    path = os.path.join(directory, filename)

    # RefSeq: "C/D box 3A".."C/D box 3K" or "C/D box 3 pseudogene" -> U3,
    # "C/D box 118" -> U8, "C/D box 13 " or "C/D box 13A".."13J" -> U13.
    u3_pattern = re.compile(r"C/D box 3(?:[A-K]| pseudogene)")
    u13_pattern = re.compile(r"C/D box 13(?: |[A-J])")

    # Ensembl: description prefixes and exact descriptions per class.
    hACA_prefixes = (
        "H/ACA box Small nucleolar RNA SNORA",
        "H/ACA box Small nucleolar RNA ACA",
        "H/ACA box SNORA",
        "H/ACA box ACA",
    )
    cd_prefixes = (
        "C/D box Small nucleolar RNA SNORD",
        "C/D box SNORD",
        "C/D box Small nucleolar RNA U2-",
        "C/D box sno",
        "C/D box Small nucleolar RNA U83B",
        "C/D box Small nucleolar RNA MBII",
    )
    exact_classes = {
        "H/ACA box Small nucleolar RNA U109": "H/ACA box",
        "C/D box Small nucleolar RNA Z40": "C/D box",
        "C/D box Small nucleolar RNA U3": "U3",
        "C/D box U8": "U8",
        "C/D box Small nucleolar RNA U13": "U13",
    }

    snorna_annotations = {}
    for record in SeqIO.parse(path, "fasta"):
        if record.id.startswith("NR_"):
            description = record.description
            if u3_pattern.search(description):
                snoRNA = "U3"
            elif "C/D box 118" in description:
                snoRNA = "U8"
            elif u13_pattern.search(description):
                snoRNA = "U13"
            else:
                # "<id> <species> ..., <name> (<symbol>), small nucleolar RNA"
                assert record.description.startswith(record.name)
                description = record.description[len(record.name):].strip()
                terms = description.split(", ")
                assert len(terms) == 3
                assert terms[0] in ("Homo sapiens small nucleolar RNA", "Homo sapiens RNA")
                assert terms[2] == "small nucleolar RNA"
                name, symbol = terms[1].rsplit(None, 1)
                assert symbol.startswith("(")
                assert symbol.endswith(")")
                if name in ("U105B small nucleolar", "U105C small nucleolar"):
                    snoRNA = "H/ACA box"  # According to snoDB
                else:
                    word1, word2 = name.rsplit(None, 1)
                    assert word1 in ("C/D box", "H/ACA box")
                    snoRNA = word1
        else:
            assert record.id.startswith("ENST")
            try:
                name, description = record.description.split(None, 1)
            except ValueError:
                # Header holds only the ID; reported as unknown below.
                snoRNA = None
            else:
                if description.startswith("ENSG"):
                    _gene, description = description.split(None, 1)
                if description.startswith(hACA_prefixes):
                    snoRNA = "H/ACA box"
                elif description.startswith(cd_prefixes):
                    snoRNA = "C/D box"
                else:
                    snoRNA = exact_classes.get(description)
                if snoRNA is None:
                    raise Exception("Unknown snoRNA with description '%s'" % record.description)
        if snoRNA in ("U3", "U8", "U13"):
            biogenesis = "Pol-II"
        elif snoRNA in ("C/D box", "H/ACA box"):
            biogenesis = "intronic"
        else:
            raise Exception("Unknown snoRNA %s" % snoRNA)
        snoRNA = "%s snoRNA" % snoRNA
        snorna_annotations[record.id] = (snoRNA, biogenesis)
    return snorna_annotations

def scarna_biogenesis(description):
    """Return the biogenesis class for a scaRNA FASTA description.

    Descriptions naming small Cajal body-specific RNA 2 or 17 are classified
    "Pol-II"; every other scaRNA is "intronic".  The number must not be
    followed by another digit, so that e.g. "RNA 20" or "RNA 27" is not
    mistaken for "RNA 2" (as a plain substring test would).
    """
    if re.search(r"small Cajal body-specific RNA (?:2|17)(?!\d)", description):
        return "Pol-II"
    return "intronic"


def read_scarna_annotations():
    """Return {record ID: ("small Cajal body-specific RNA", biogenesis)}.

    Reads the scaRNA filter FASTA; every record ID must start with "NR_".
    Biogenesis is decided per record by scarna_biogenesis().
    """
    directory = "/osc-fs_home/mdehoon/Data/CASPARs/Filters"
    filename = "scaRNA.fa"
    path = os.path.join(directory, filename)
    scarna_annotations = {}
    for record in SeqIO.parse(path, "fasta"):
        assert record.id.startswith("NR_")
        gene = "small Cajal body-specific RNA"
        scarna_annotations[record.id] = (gene, scarna_biogenesis(record.description))
    return scarna_annotations

# Per-transcript (label, biogenesis) lookup tables for the short-RNA classes
# whose category depends on the individual transcript.
scrna_annotations = read_scrna_annotations()
snrna_annotations = read_snrna_annotations()
snorna_annotations = read_snorna_annotations()
scarna_annotations = read_scarna_annotations()

# hg38 genome, opened with Bio.SeqIO's twoBit parser; chromosomes are later
# looked up by name (genome[chromosome]).
assembly = 'hg38'
directory = "/osc-fs_home/scratch/mdehoon/Data/Genomes"
filename = assembly + ".2bit"
path = os.path.join(directory, assembly, filename)
genome = SeqIO.parse(path, "twobit")

# BAM files for this dataset live in <CASPARs>/<dataset>/Mapping.
directory = "/osc-fs_home/mdehoon/Data/CASPARs"
subdirectory = os.path.join(directory, dataset, "Mapping")

# counts[category][letter] accumulates alignment weight keyed by the RNA's
# 5'-end nucleotide.  Upper case: the nucleotide equals the genome base at the
# alignment's 5' end.  Lower case: it differs from that base, or the 5'-most
# CIGAR operation is an insertion.  Unmapped reads use the first base of
# their query sequence.
counts = defaultdict(lambda: {letter: 0 for letter in "ACGTNacgtn"})
filename = "%s.bam" % library
path = os.path.join(subdirectory, filename)
print("Reading %s" % path)
with pysam.AlignmentFile(path, "rb") as alignments:
    for alignment in alignments:
        # XT: target class assigned by the mapping pipeline; only unmapped
        # reads may lack it.
        try:
            target = alignment.get_tag("XT")
        except KeyError:
            assert alignment.is_unmapped
            target = None
        if target is None:
            annotation = "unmapped"
        elif target in annotations:
            annotation = target
        else:
            assert target in ("mRNA", "lncRNA", "gencode", "fantomcat",
                              "novel", "genome", 'TERC', 'MALAT1', 'snhg')
            try:
                annotation = alignment.get_tag("XA")
            except KeyError:
                annotation = "other_intergenic"
        # XR: comma-separated transcript IDs.  Required for the short-RNA
        # classes, ignored for mRNA/lncRNA/gencode/fantomcat, and an error
        # for any other target.
        if target in ("snRNA", "scRNA", "snoRNA", "scaRNA"):
            transcripts = alignment.get_tag("XR")
        elif target in ("mRNA", "lncRNA", "gencode", "fantomcat"):
            transcripts = None
        else:
            try:
                transcripts = alignment.get_tag("XR")
            except KeyError:
                transcripts = None
            else:
                raise Exception(target)
        if transcripts is None:
            transcripts = ['-']
        else:
            transcripts = transcripts.split(",")
        if alignment.is_unmapped:
            count = 1
            nucleotide = alignment.query_sequence[0]
        else:
            # Split each read's weight evenly over its NH reported alignments.
            multimap = alignment.get_tag("NH")
            count = 1 / multimap
            cigar = alignment.cigartuples
            if alignment.is_secondary:
                # Secondary records reuse the 5' nucleotide computed from the
                # preceding primary record of the same read.
                assert alignment.query_name == query_name
            else:
                query_name = alignment.query_name
                if alignment.is_reverse:
                    rna_nucleotide = reverse_complement(alignment.query_sequence[-1])
                else:
                    rna_nucleotide = alignment.query_sequence[0]
            # The RNA's 5' end is the last CIGAR operation on the reverse
            # strand and the first one otherwise.
            if alignment.is_reverse:
                operation = cigar[-1][0]
            else:
                operation = cigar[0][0]
            assert rna_nucleotide in "ACGTN"
            if operation == pysam.CINS:
                nucleotide = rna_nucleotide.lower()
            else:
                assert operation == pysam.CMATCH
                chromosome = alignment.reference_name
                reference = genome[chromosome]
                # Fetch only the single genome base at the 5' end;
                # reference_end is exclusive, so on the reverse strand that
                # base sits at reference_end - 1.
                if alignment.is_reverse:
                    end = alignment.reference_end
                    dna_nucleotide = reverse_complement(reference[end - 1:end].seq[0])
                else:
                    start = alignment.reference_start
                    dna_nucleotide = reference[start:start + 1].seq[0]
                if rna_nucleotide == dna_nucleotide.upper():
                    nucleotide = rna_nucleotide
                else:
                    nucleotide = rna_nucleotide.lower()
        # Spread the weight evenly over the transcripts the read is assigned to.
        count /= len(transcripts)
        for transcript in transcripts:
            if annotation in ("RPPH", "RMRP", "yRNA", "vRNA", "snar", "tRNA"):
                category = "Pol-III short RNA"
            elif annotation in ("rRNA", "chrM", "unmapped",
                                "histone",
                                "sense_upstream", "sense_proximal",
                                "sense_distal", "sense_distal_upstream",
                                "prompt", "antisense",
                                "antisense_distal", "antisense_distal_upstream",
                                "FANTOM5_enhancer",
                                "roadmap_enhancer", "roadmap_dyadic",
                                "novel_enhancer_CAGE", "novel_enhancer_HiSeq",
                                "other_intergenic"):
                category = annotation
            elif annotation == "scRNA":
                gene, biogenesis = scrna_annotations[transcript]
                if biogenesis == "intronic":
                    category = "intronic short RNA"
                elif biogenesis == "Pol-III":
                    category = "Pol-III short RNA"
                else:
                    raise Exception("Unexpected biogenesis pathway %s" % biogenesis)
            elif annotation == "snRNA":
                gene, biogenesis = snrna_annotations[transcript]
                if biogenesis == "Pol-II":
                    category = "Pol-II short RNA"
                elif biogenesis == "Pol-III":
                    category = "Pol-III short RNA"
                else:
                    raise Exception("Unexpected biogenesis pathway %s" % biogenesis)
            elif annotation == "snoRNA":
                gene, biogenesis = snorna_annotations[transcript]
                if biogenesis == "Pol-II":
                    category = "Pol-II short RNA"
                elif biogenesis == "intronic":
                    category = "intronic short RNA"
                else:
                    print(alignment)
                    raise Exception("Unexpected biogenesis pathway %s" % biogenesis)
            elif annotation == "scaRNA":
                gene, biogenesis = scarna_annotations[transcript]
                if biogenesis == "Pol-II":
                    category = "Pol-II short RNA"
                elif biogenesis == "intronic":
                    category = "intronic short RNA"
                else:
                    raise Exception("Unexpected biogenesis pathway %s" % biogenesis)
            elif annotation in ("pretRNA", "presnRNA", "presnoRNA", "prescaRNA"):
                category = "short RNA precursor"
            else:
                raise Exception("Unknown annotation %s" % annotation)
            counts[category][nucleotide] += count
        if dataset == "MiSeq":
            # NOTE(review): skips the record immediately following each
            # processed one -- presumably the mate of a paired-end read;
            # confirm.  The default stops a missing final mate from raising
            # an uncaught StopIteration at module level.
            next(alignments, None)

# Output row order: the annotation column of annotations.<dataset>.txt.  The
# header line must start with "#" and its first three tab-separated columns
# must be rank, annotation, transcript.
filename = "annotations.%s.txt" % dataset
print("Reading", filename)
categories = []
with open(filename) as stream:
    line = next(stream)
    assert line.startswith("#")
    words = line[1:].strip().split("\t")
    assert words[0] == "rank"
    assert words[1] == "annotation"
    assert words[2] == "transcript"
    for line in stream:
        if not line.strip():
            continue  # tolerate blank (e.g. trailing) lines
        words = line.strip().split("\t")
        annotation = words[1].strip()
        categories.append(annotation)
# Keep the first occurrence of each category, preserving file order.
categories = list(dict.fromkeys(categories))


# One row per category, in annotations-file order:
# "<category>\t<A>,<C>,<G>,<T>,<a>,<c>,<g>,<t>".  N/n counts are not written.
filename = "firstnucleotide.%s.%s.txt" % (dataset, library)
print("Writing", filename)
with open(filename, 'w') as stream:
    stream.write("#annotation\t%s,%s:A,C,G,T,a,c,g,t\n" % (dataset, library))
    for category in categories:
        count = counts[category]
        # Counts are float sums of fractional weights (1/NH/len(transcripts)),
        # so round to the nearest integer: "%d" truncates, turning e.g. ten
        # weights of 1/10 (summing to 0.9999999999999999) into 0.
        fields = ",".join("%.0f" % count[letter] for letter in "ACGTacgt")
        stream.write("%s\t%s\n" % (category, fields))
